// Copyright 2018 Google LLC All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transport

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/google/go-containerregistry/pkg/authn"
	"github.com/google/go-containerregistry/pkg/name"
)

var (
	// testReference is the tag shared by the tests in this file. It is parsed
	// once at package initialization; an invalid literal panics immediately
	// instead of silently leaving a zero-value name.Tag behind that would make
	// every test fail in a confusing way.
	testReference = func() name.Tag {
		tag, err := name.NewTag("localhost:8080/user/image:latest", name.StrictValidation)
		if err != nil {
			panic(err)
		}
		return tag
	}()
)

// TestTransportNoActionIfTransportIsAlreadyWrapper verifies that NewWithContext
// accepts a transport that is already a *Wrapper without returning an error.
// The proxy target answers every request with a 400, and its message records
// that it is not expected to be contacted.
func TestTransportNoActionIfTransportIsAlreadyWrapper(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("WWW-Authenticate", `Bearer realm="http://foo.io"`)
		http.Error(w, "Should not contact the server", http.StatusBadRequest)
	}))
	defer srv.Close()

	// Route every outgoing request to the test server.
	viaProxy := &http.Transport{
		Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(srv.URL) },
	}
	alreadyWrapped := &Wrapper{inner: viaProxy}

	registry := testReference.Context().Registry
	scopes := []string{testReference.Scope(PullScope)}

	_, err := NewWithContext(context.Background(), registry, nil, alreadyWrapped, scopes)
	if err != nil {
		t.Errorf("NewWithContext unexpected error %s", err)
	}
}

// TestTransportSelectionAnonymous checks that, when the inner transport answers
// with an empty 200 OK, NewWithContext succeeds with Basic credentials, and a
// request built with an http:// URL is sent through the resulting transport
// with the https scheme.
func TestTransportSelectionAnonymous(t *testing.T) {
	// Record the requests we get in the inner transport. newRecorder is defined
	// elsewhere; presumably it replays cannedResponse for each request.
	cannedResponse := http.Response{
		Status:     http.StatusText(http.StatusOK),
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader("")),
	}
	recorder := newRecorder(&cannedResponse, nil)

	basic := &authn.Basic{Username: "foo", Password: "bar"}
	reg := testReference.Context().Registry

	tp, err := NewWithContext(context.Background(), reg, basic, recorder, []string{testReference.Scope(PullScope)})
	if err != nil {
		// tp cannot be relied on when err is non-nil; continuing would risk a
		// nil-interface panic at RoundTrip below.
		t.Fatalf("NewWithContext() = %v", err)
	}

	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s/v2/anything", reg), nil)
	if err != nil {
		t.Fatalf("Unexpected error during NewRequest: %v", err)
	}
	if _, err := tp.RoundTrip(req); err != nil {
		t.Fatalf("Unexpected error during RoundTrip: %v", err)
	}

	// Two requests are expected. Presumably these are the ping issued by
	// NewWithContext and then req itself via RoundTrip (TODO confirm).
	if got, want := len(recorder.reqs), 2; got != want {
		t.Fatalf("expected %d requests, got %d", want, got)
	}
	// req was built with http://, but the transport should have sent it over https.
	recorded := recorder.reqs[1]
	if got, want := recorded.URL.Scheme, "https"; got != want {
		t.Errorf("wrong scheme, want %s got %s", want, got)
	}
}

// TestTransportSelectionBasic checks that when the registry answers with a
// Basic WWW-Authenticate challenge, NewWithContext returns a *Wrapper whose
// inner transport is a *basicTransport.
func TestTransportSelectionBasic(t *testing.T) {
	// The proxy target answers every request with a 401 and a Basic challenge.
	server := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("WWW-Authenticate", `Basic`)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
		}))
	defer server.Close()
	tprt := &http.Transport{
		Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(server.URL) },
	}

	basic := &authn.Basic{Username: "foo", Password: "bar"}

	tp, err := NewWithContext(context.Background(), testReference.Context().Registry, basic, tprt, []string{testReference.Scope(PullScope)})
	if err != nil {
		t.Fatalf("NewWithContext() = %v", err)
	}
	tpw, ok := tp.(*Wrapper)
	if !ok {
		t.Fatalf("NewWithContext(); got %T, want *Wrapper", tp)
	}
	// Report the inner transport's type; tp itself is known to be *Wrapper here.
	if _, ok := tpw.inner.(*basicTransport); !ok {
		t.Errorf("NewWithContext(); got %T, want *basicTransport", tpw.inner)
	}
}

// badAuth is passed as the credentials argument to NewWithContext to exercise
// its error path: its Authorization method never succeeds.
type badAuth struct{}

// Authorization always returns a nil config together with a fixed error.
func (*badAuth) Authorization() (*authn.AuthConfig, error) {
	refusal := errors.New("sorry dave, I'm afraid I can't let you do that")
	return nil, refusal
}

func TestTransportBadAuth(t *testing.T) {
	server := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("WWW-Authenticate", `Bearer realm="http://foo.io"`)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
		}))
	defer server.Close()
	tprt := &http.Transport{
		Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(server.URL) },
	}

	if _, err := NewWithContext(context.Background(), testReference.Context().Registry, &badAuth{}, tprt, []string{testReference.Scope(PullScope)}); err == nil {
		t.Errorf("NewWithContext() expected err, got nil")
	}
}

// TestTransportSelectionBearer checks that a Bearer WWW-Authenticate challenge
// leads NewWithContext to request a token using the Basic credentials, and to
// return a *Wrapper whose inner transport is a *bearerTransport. The handler
// also validates the token request itself.
func TestTransportSelectionBearer(t *testing.T) {
	request := 0
	server := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			request++
			switch request {
			case 1:
				// This is an https request that fails, causing us to fall back to http.
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			case 2:
				// Challenge the client to authenticate with a Bearer token.
				w.Header().Set("WWW-Authenticate", `Bearer realm="http://foo.io"`)
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
			case 3:
				// Token request. This runs on the server's goroutine, so only
				// t.Errorf (never t.Fatalf) may be used here.
				hdr := r.Header.Get("Authorization")
				if !strings.HasPrefix(hdr, "Basic ") {
					t.Errorf("Header.Get(Authorization); got %v, want Basic prefix", hdr)
				}
				if got, want := r.FormValue("scope"), testReference.Scope(PullScope); got != want {
					t.Errorf("FormValue(scope); got %v, want %v", got, want)
				}
				// Check that the service isn't set (we didn't specify it above)
				// https://github.com/google/go-containerregistry/issues/1359
				if got, want := r.FormValue("service"), ""; got != want {
					t.Errorf("FormValue(service); got %q, want %q", got, want)
				}
				w.Write([]byte(`{"token": "dfskdjhfkhsjdhfkjhsdf"}`))
			}
		}))
	defer server.Close()
	tprt := &http.Transport{
		Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(server.URL) },
	}

	basic := &authn.Basic{Username: "foo", Password: "bar"}
	tp, err := NewWithContext(context.Background(), testReference.Context().Registry, basic, tprt, []string{testReference.Scope(PullScope)})
	if err != nil {
		t.Fatalf("NewWithContext() = %v", err)
	}
	tpw, ok := tp.(*Wrapper)
	if !ok {
		t.Fatalf("NewWithContext(); got %T, want *Wrapper", tp)
	}
	// Report the inner transport's type; tp itself is known to be *Wrapper here.
	if _, ok := tpw.inner.(*bearerTransport); !ok {
		t.Errorf("NewWithContext(); got %T, want *bearerTransport", tpw.inner)
	}
}

// TestTransportSelectionBearerMissingRealm checks that a Bearer challenge
// without a realm parameter makes NewWithContext fail with an error that
// mentions "missing realm".
func TestTransportSelectionBearerMissingRealm(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		// The challenge names a service but deliberately omits the realm.
		w.Header().Set("WWW-Authenticate", `Bearer service="gcr.io"`)
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
	}))
	defer srv.Close()

	viaProxy := &http.Transport{
		Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(srv.URL) },
	}
	creds := &authn.Basic{Username: "foo", Password: "bar"}
	registry := testReference.Context().Registry
	scopes := []string{testReference.Scope(PullScope)}

	tp, err := NewWithContext(context.Background(), registry, creds, viaProxy, scopes)
	if err != nil && strings.Contains(err.Error(), "missing realm") {
		return
	}
	t.Errorf("NewWithContext() = %v, %v", tp, err)
}

// TestTransportSelectionBearerAuthError checks that NewWithContext returns an
// error when the request following a Bearer challenge fails with a 500. That
// second request is presumably the token fetch (TODO confirm).
func TestTransportSelectionBearerAuthError(t *testing.T) {
	calls := 0
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		calls++
		if calls == 1 {
			// First contact: demand a Bearer token.
			w.Header().Set("WWW-Authenticate", `Bearer realm="http://foo.io"`)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		if calls == 2 {
			http.Error(w, "Oops", http.StatusInternalServerError)
		}
	}))
	defer srv.Close()

	viaProxy := &http.Transport{
		Proxy: func(*http.Request) (*url.URL, error) { return url.Parse(srv.URL) },
	}
	creds := &authn.Basic{Username: "foo", Password: "bar"}
	registry := testReference.Context().Registry
	scopes := []string{testReference.Scope(PullScope)}

	tp, err := NewWithContext(context.Background(), registry, creds, viaProxy, scopes)
	if err == nil {
		t.Errorf("NewWithContext() = %v", tp)
	}
}

// TestTransportAlwaysTriesHttps checks that NewWithContext contacts the
// registry over TLS even when the registry is on localhost.
func TestTransportAlwaysTriesHttps(t *testing.T) {
	// Use a NewTLSServer so that this speaks TLS even though it's localhost.
	// This ensures that we try https even for local registries.
	count := 0
	server := httptest.NewTLSServer(
		http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			count++
			w.Write([]byte(`{"token": "dfskdjhfkhsjdhfkjhsdf"}`))
		}))
	defer server.Close()

	// Fatal on setup failures: a nil u or an invalid registry would otherwise
	// cause a panic or a misleading failure further down.
	u, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("Unexpected error during url.Parse: %v", err)
	}
	registry, err := name.NewRegistry(u.Host, name.WeakValidation)
	if err != nil {
		t.Fatalf("Unexpected error during NewRegistry: %v", err)
	}

	basic := &authn.Basic{Username: "foo", Password: "bar"}
	// server.Client().Transport trusts the test server's self-signed certificate.
	tp, err := NewWithContext(context.Background(), registry, basic, server.Client().Transport, []string{testReference.Scope(PullScope)})
	if err != nil {
		t.Fatalf("NewWithContext() = %v, %v", tp, err)
	}
	if count == 0 {
		t.Errorf("failed to call TLS localhost server")
	}
}
